import sys
import os
sys.path.append('../../../')  # replace this with root directory
import PIL.Image as Image
import torch
import random
from torch.utils.data import Dataset
import numpy as np
import re
from tqdm import tqdm
import copy
import re
import cv2 as cv
import glob
import argparse

def get_pixel_xy_from_image(image):
    """Return the mean index of the non-zero entries of an image.

    The first returned component is the mean index along axis 0 (rows) and
    the second is the mean index along axis 1 (columns). For a multi-channel
    array every non-zero *element* contributes one sample, so a pixel that is
    non-zero in several channels is weighted accordingly.

    Args:
        image: array with at least two dimensions (e.g. the result of
            ``cv.imread``).

    Returns:
        A float array of length 2 with the mean (axis-0, axis-1) indices, or
        ``[-10, -10]`` when the image has no non-zero entry.

    Raises:
        ValueError: if ``image`` is None (e.g. the image could not be read).
    """
    if image is None:
        raise ValueError("image is None; it could not be loaded")
    xs, ys = np.where(image != 0)[:2]
    # Check for emptiness *before* averaging: np.mean on an empty array emits
    # a RuntimeWarning and yields NaN. Both index arrays always have the same
    # length, so testing one of them is enough.
    if len(xs) == 0:
        return np.full(2, -10.0)  # use -10 pixel location to denote missing
    return np.array([np.mean(xs), np.mean(ys)])


def get_state_from_images(valid_frame_paths):  # images are ordered from newest to oldest
    """Build a [position, velocity] state from a stack of frame images.

    Each image is reduced to a 2D position with ``get_pixel_xy_from_image``.
    The velocity is the mean of the differences between consecutive
    positions (newer minus older).

    Args:
        valid_frame_paths: image file paths ordered from newest to oldest.

    Returns:
        A float array of length 4: the newest position followed by the mean
        velocity. With a single frame the velocity is zero.

    Raises:
        ValueError: if no path is given or an image cannot be read.
    """
    if len(valid_frame_paths) == 0:
        raise ValueError("valid_frame_paths must contain at least one image path")

    positions = []
    for path in valid_frame_paths:
        image = cv.imread(path)
        # cv.imread signals failure by returning None instead of raising.
        if image is None:
            raise ValueError(f"Could not read image: {path}")
        positions.append(get_pixel_xy_from_image(image))

    # NOTE(review): a missing object is encoded as position (-10, -10), so a
    # frame pair where only one side is missing yields a large artificial
    # velocity -- confirm this is what downstream consumers expect.
    if len(positions) > 1:
        deltas = [newer - older for newer, older in zip(positions, positions[1:])]
        velocities = np.mean(deltas, axis=0)
    else:
        # A single frame carries no motion information; averaging an empty
        # list would produce a scalar NaN that cannot be concatenated.
        velocities = np.zeros(2)
    return np.concatenate([positions[0], velocities])
    

# Keys of a per-frame record that process_state_from_dataset treats as
# non-object entries: their values are left untouched, while every other
# key's value is truncated to STATE_DIM components.
non_object_keys = ["ITR", "Done", "Reward", "VALID_NAMES", "TRACE", "Action"]
# Number of leading components kept from each object's state. It equals the
# length of the state built by get_state_from_images (2 position + 2 velocity).
STATE_DIM = 4

def process_state_from_dataset(obj_data, base_folder, frame_stack=2, frames_per_folder=2000):
    """Recompute object states from per-object frame images and pickle them.

    For every frame ``i`` the record ``obj_data[i]`` is deep-copied, each
    object entry (any key not in ``non_object_keys``) is truncated to
    ``STATE_DIM`` components, and every object that has an image
    ``<base_folder>/<i // frames_per_folder>/state<i>_<name>.png`` gets its
    state replaced by the output of ``get_state_from_images`` over a stack of
    ``frame_stack`` frames (newest first). The resulting list is written to
    ``<base_folder>/object_state_from_img.pkl``.

    ``obj_data`` itself is not modified.

    Args:
        obj_data: sequence of per-frame dicts, indexed by frame number; each
            dict must contain a ``'Done'`` entry.
        base_folder: directory holding the numbered image sub-folders; the
            output pickle is written here as well.
        frame_stack: number of frames (current + previous) used per object.
        frames_per_folder: number of consecutive frames stored in each
            numbered sub-folder.

    Raises:
        AssertionError: if an image refers to an object that is absent from
            the frame's record.
        ValueError: if a computed object state contains NaN.
    """
    import pickle

    new_obj_data = []

    for frame_number in tqdm(range(len(obj_data))):
        curr_obj_dict = obj_data[frame_number]

        # Truncate on the copy so the caller's records are left untouched.
        new_object_dict = copy.deepcopy(curr_obj_dict)
        for key in new_object_dict:
            if key not in non_object_keys:
                new_object_dict[key] = new_object_dict[key][:STATE_DIM]

        folder_path = os.path.join(base_folder, str(frame_number // frames_per_folder))
        prefix = f'state{frame_number}_'
        suffix = '.png'
        # Escape the directory part so special characters in base_folder are
        # not interpreted as glob wildcards.
        obj_files = glob.glob(os.path.join(glob.escape(folder_path), f'{prefix}*{suffix}'))
        # Derive the object name from the file name only; matching against
        # the full path could pick up a look-alike directory name.
        obj_names = [
            os.path.basename(obj_file)[len(prefix):-len(suffix)]
            for obj_file in obj_files
        ]

        for obj_name in obj_names:
            assert obj_name in curr_obj_dict, f"Object {obj_name} not found in object data"

            current_image_path = os.path.join(folder_path, f'{prefix}{obj_name}{suffix}')
            # Check previous frames and add them or pad with the most recent frame
            valid_frame_paths = [current_image_path]

            # NOTE(review): each previous frame is validated independently, so
            # with frame_stack > 2 a frame older than a 'Done' frame can still
            # be used -- confirm whether the stack should stop at the first
            # invalid frame instead.
            for offset in range(1, frame_stack):
                prev_frame_number = frame_number - offset
                prev_frame_path = os.path.join(
                    base_folder,
                    str(prev_frame_number // frames_per_folder),
                    f'state{prev_frame_number}_{obj_name}.png',
                )
                # The range check comes first so a negative index never wraps
                # around to the end of obj_data.
                usable = (
                    prev_frame_number >= 0
                    and not obj_data[prev_frame_number]['Done']
                    and os.path.exists(prev_frame_path)
                )
                if usable:
                    valid_frame_paths.append(prev_frame_path)
                else:
                    # Pad with the most recent valid frame
                    valid_frame_paths.append(valid_frame_paths[-1])

            obj_state = get_state_from_images(valid_frame_paths)
            if np.isnan(obj_state).any():
                raise ValueError(
                    f"NaN in state of object {obj_name} at frame {frame_number}: {obj_state}"
                )
            new_object_dict[obj_name] = obj_state

        new_obj_data.append(new_object_dict)

    # dump the pickle file
    with open(os.path.join(base_folder, 'object_state_from_img.pkl'), 'wb') as f:
        pickle.dump(new_obj_data, f)



if __name__ == "__main__":
    # Command line: a single optional --base-dir giving the data directory.
    cli = argparse.ArgumentParser(description='Get converted state from data')
    cli.add_argument(
        '--base-dir',
        default="",
        help='base directory to save and load results',
    )
    options = cli.parse_args()
    data_dir = options.base_dir

    # Project-local reader, imported only when the file is run as a script.
    from Record.file_management import read_obj_dumps

    # Load the object dumps from the base directory, then rebuild the states
    # from the images stored alongside them.
    dumped_objects = read_obj_dumps(data_dir, i=0, rng=-1, filename='object_dumps.txt')
    process_state_from_dataset(dumped_objects, data_dir, frame_stack=2)


    